{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+--------------------+---------+------------------+------------------+------------------+--------------------+-------------------+\n",
      "|               Email|             Address|   Avatar|Avg Session Length|       Time on App|   Time on Website|Length of Membership|Yearly Amount Spent|\n",
      "+--------------------+--------------------+---------+------------------+------------------+------------------+--------------------+-------------------+\n",
      "|mstephenson@ferna...|835 Frank TunnelW...|   Violet| 34.49726772511229| 12.65565114916675| 39.57766801952616|  4.0826206329529615|  587.9510539684005|\n",
      "|   hduke@hotmail.com|4547 Archer Commo...|DarkGreen| 31.92627202636016|11.109460728682564|37.268958868297744|    2.66403418213262|  392.2049334443264|\n",
      "|    pallen@yahoo.com|24645 Valerie Uni...|   Bisque|33.000914755642675|11.330278057777512|37.110597442120856|   4.104543202376424| 487.54750486747207|\n",
      "+--------------------+--------------------+---------+------------------+------------------+------------------+--------------------+-------------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "from pyspark.ml.regression import LinearRegression\n",
    "\n",
    "spark = SparkSession.builder.appName('Linear_regression_sample').getOrCreate()\n",
    "data = spark.read.csv('Ecommerce_Customers.csv', header=True, inferSchema=True)\n",
    "data.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- Email: string (nullable = true)\n",
      " |-- Address: string (nullable = true)\n",
      " |-- Avatar: string (nullable = true)\n",
      " |-- Avg Session Length: double (nullable = true)\n",
      " |-- Time on App: double (nullable = true)\n",
      " |-- Time on Website: double (nullable = true)\n",
      " |-- Length of Membership: double (nullable = true)\n",
      " |-- Yearly Amount Spent: double (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "data.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hduke@hotmail.com\n",
      "4547 Archer CommonDiazchester, CA 06566-8576\n",
      "DarkGreen\n",
      "31.92627202636016\n",
      "11.109460728682564\n",
      "37.268958868297744\n",
      "2.66403418213262\n",
      "392.2049334443264\n"
     ]
    }
   ],
   "source": [
    "for item in data.take(2)[1]:\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Email',\n",
       " 'Address',\n",
       " 'Avatar',\n",
       " 'Avg Session Length',\n",
       " 'Time on App',\n",
       " 'Time on Website',\n",
       " 'Length of Membership',\n",
       " 'Yearly Amount Spent']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.linalg import Vectors\n",
    "from pyspark.ml.feature import VectorAssembler\n",
    "\n",
    "data.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Row(Email='mstephenson@fernandez.com', Address='835 Frank TunnelWrightmouth, MI 82180-9605', Avatar='Violet', Avg Session Length=34.49726772511229, Time on App=12.65565114916675, Time on Website=39.57766801952616, Length of Membership=4.0826206329529615, Yearly Amount Spent=587.9510539684005, features=DenseVector([34.4973, 12.6557, 39.5777, 4.0826]))]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "assembler = VectorAssembler(inputCols = ['Avg Session Length', 'Time on App', 'Time on Website', 'Length of Membership'], outputCol = 'features')\n",
    "output = assembler.transform(data)\n",
    "output.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- Email: string (nullable = true)\n",
      " |-- Address: string (nullable = true)\n",
      " |-- Avatar: string (nullable = true)\n",
      " |-- Avg Session Length: double (nullable = true)\n",
      " |-- Time on App: double (nullable = true)\n",
      " |-- Time on Website: double (nullable = true)\n",
      " |-- Length of Membership: double (nullable = true)\n",
      " |-- Yearly Amount Spent: double (nullable = true)\n",
      " |-- features: vector (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "output.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-------------------+\n",
      "|            features|Yearly Amount Spent|\n",
      "+--------------------+-------------------+\n",
      "|[34.4972677251122...|  587.9510539684005|\n",
      "|[31.9262720263601...|  392.2049334443264|\n",
      "|[33.0009147556426...| 487.54750486747207|\n",
      "+--------------------+-------------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "final_data = output.select(['features', 'Yearly Amount Spent'])\n",
    "final_data.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- features: vector (nullable = true)\n",
      " |-- Yearly Amount Spent: double (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "final_data.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+-------------------+\n",
      "|summary|Yearly Amount Spent|\n",
      "+-------+-------------------+\n",
      "|  count|                344|\n",
      "|   mean| 501.56891944874263|\n",
      "| stddev|  80.05283936790765|\n",
      "|    min|   266.086340948469|\n",
      "|    max|  765.5184619388373|\n",
      "+-------+-------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "train, test = final_data.randomSplit([0.7, 0.3])\n",
    "train.describe().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+-------------------+\n",
      "|summary|Yearly Amount Spent|\n",
      "+-------+-------------------+\n",
      "|  count|                156|\n",
      "|   mean|  494.3417361469743|\n",
      "| stddev|  77.68504211687474|\n",
      "|    min| 256.67058229005585|\n",
      "|    max|  700.9170916173961|\n",
      "+-------+-------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "test.describe().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = LinearRegression(labelCol = 'Yearly Amount Spent')\n",
    "lr_model = lr.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------+\n",
      "|          residuals|\n",
      "+-------------------+\n",
      "|  9.920816513034879|\n",
      "|  10.70091171045658|\n",
      "| -4.052157537914752|\n",
      "|  4.197582213270721|\n",
      "|-13.117617411562094|\n",
      "| -7.510566592645546|\n",
      "|  22.01350655518786|\n",
      "|  18.56111122682347|\n",
      "| 3.5394008658557823|\n",
      "|0.23861684488724677|\n",
      "| 2.8285428337596272|\n",
      "| -5.705837609900527|\n",
      "| -4.489310758540171|\n",
      "| -5.753400234926289|\n",
      "|  3.839466127812159|\n",
      "|-2.1099929840607956|\n",
      "|  -4.16714747346856|\n",
      "| -26.51752993610296|\n",
      "| 2.0634824112656247|\n",
      "|-4.8095343970560975|\n",
      "+-------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "test_result = lr_model.evaluate(test)\n",
    "test_result.residuals.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9.935028196912466"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_result.rootMeanSquaredError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.983538996881317"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_result.r2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+-------------------+\n",
      "|summary|Yearly Amount Spent|\n",
      "+-------+-------------------+\n",
      "|  count|                500|\n",
      "|   mean|  499.3140382585909|\n",
      "| stddev|   79.3147815497068|\n",
      "|    min| 256.67058229005585|\n",
      "|    max|  765.5184619388373|\n",
      "+-------+-------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "final_data.describe().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+\n",
      "|            features|\n",
      "+--------------------+\n",
      "|[29.5324289670579...|\n",
      "|[30.7377203726281...|\n",
      "|[30.8364326747734...|\n",
      "|[31.0472221394875...|\n",
      "|[31.0662181616375...|\n",
      "|[31.1280900496166...|\n",
      "|[31.2834474760581...|\n",
      "|[31.3123495994443...|\n",
      "|[31.3662121671876...|\n",
      "|[31.3895854806643...|\n",
      "|[31.4459724827577...|\n",
      "|[31.5147378578019...|\n",
      "|[31.5171218025062...|\n",
      "|[31.5257524169682...|\n",
      "|[31.5316044825729...|\n",
      "|[31.5761319713222...|\n",
      "|[31.6253601348306...|\n",
      "|[31.6739155032749...|\n",
      "|[31.7366356860502...|\n",
      "|[31.7656188210424...|\n",
      "+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "unlabeled_data = test.select('features')\n",
    "unlabeled_data.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+------------------+\n",
      "|            features|        prediction|\n",
      "+--------------------+------------------+\n",
      "|[29.5324289670579...| 398.7195345595926|\n",
      "|[30.7377203726281...| 451.0798304857733|\n",
      "|[30.8364326747734...|471.55405796490436|\n",
      "|[31.0472221394875...| 388.2998169757507|\n",
      "|[31.0662181616375...|462.05091061923645|\n",
      "|[31.1280900496166...| 564.7632533397002|\n",
      "|[31.2834474760581...| 569.7675828704796|\n",
      "|[31.3123495994443...|445.03030680111715|\n",
      "|[31.3662121671876...|427.04948169062914|\n",
      "|[31.3895854806643...|409.83099421509564|\n",
      "|[31.4459724827577...|482.04842210136894|\n",
      "|[31.5147378578019...|495.51832560636194|\n",
      "|[31.5171218025062...| 280.4077314089259|\n",
      "|[31.5257524169682...| 449.7190270448082|\n",
      "|[31.5316044825729...| 432.6761396015504|\n",
      "|[31.5761319713222...| 543.3365769733891|\n",
      "|[31.6253601348306...|380.50404823039275|\n",
      "|[31.6739155032749...|502.24259784598416|\n",
      "|[31.7366356860502...| 494.8699638442663|\n",
      "|[31.7656188210424...|501.36361603266323|\n",
      "+--------------------+------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions = lr_model.transform(unlabeled_data)\n",
    "predictions.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_mxnet_p36",
   "language": "python",
   "name": "conda_mxnet_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
